/**
 * This file includes code originally from the Pytorch-Correlation-extension repository:
 * https://github.com/ClementPinard/Pytorch-Correlation-extension
 * Licensed under the MIT License. See THIRD_PARTY_LICENSES.md for details.
 */

#include <torch/extension.h>
using namespace torch;

#include <vector>

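// True when x lies inside [0, H) and y lies inside [0, W). The callers below
// reuse it for index pairs: rows are checked with (i1, i2, iH, iH) and
// columns with (j1, j2, iW, iW).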
#define WITHIN_BOUNDS(x, y, H, W) (x >= 0 && x < H && y >= 0 && y < W)

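// Accumulates into *dst the dot product, over channels and a kH x kW dilated
// window, between input1 anchored at (u, v) and input2 anchored at
// (u + shiftU, v + shiftV). Out-of-bounds taps are skipped, which gives
// implicit zero padding.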
template <typename scalar_t>
static void correlate_patch(TensorAccessor<scalar_t, 3> input1, TensorAccessor<scalar_t, 3> input2, scalar_t *dst,
                            int kH, int kW, int dilationH, int dilationW, int u, int v, int shiftU, int shiftV) {
    const int C = input1.size(0);
    const int iH = input1.size(1);
    const int iW = input1.size(2);
    for (int c = 0; c < C; ++c) {
        for (int i = 0; i < kH; ++i) {
            int i1 = u + i * dilationH;
            int i2 = i1 + shiftU;
            if (WITHIN_BOUNDS(i1, i2, iH, iH)) {
                for (int j = 0; j < kW; ++j) {
                    int j1 = v + j * dilationW;
                    int j2 = j1 + shiftV;
                    if (WITHIN_BOUNDS(j1, j2, iW, iW)) {
                        scalar_t v1 = input1[c][i1][j1];
                        scalar_t v2 = input2[c][i2][j2];
                        *dst += v1 * v2;
                    }
                }
            }
        }
    }
}

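// Backward counterpart of correlate_patch: scatters gradOutput (the gradient
// of a single output element) into gradInput1 and gradInput2. Since the
// forward op sums products v1 * v2, each input receives gradOutput scaled by
// the other input's value.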
template <typename scalar_t>
static void correlate_patch_grad(TensorAccessor<scalar_t, 3> input1, TensorAccessor<scalar_t, 3> gradInput1,
                                 TensorAccessor<scalar_t, 3> input2, TensorAccessor<scalar_t, 3> gradInput2,
                                 scalar_t gradOutput, int kH, int kW, int dilationH, int dilationW, int u, int v,
                                 int shiftU, int shiftV) {
    const int C = input1.size(0);
    const int iH = input1.size(1);
    const int iW = input1.size(2);

    for (int c = 0; c < C; ++c) {
        for (int i = 0; i < kH; ++i) {
            int i1 = u + i * dilationH;
            int i2 = i1 + shiftU;
            if (WITHIN_BOUNDS(i1, i2, iH, iH)) {
                for (int j = 0; j < kW; ++j) {
                    int j1 = v + j * dilationW;
                    int j2 = j1 + shiftV;
                    if (WITHIN_BOUNDS(j1, j2, iW, iW)) {
                        scalar_t v1 = input1[c][i1][j1];
                        scalar_t v2 = input2[c][i2][j2];
                        gradInput2[c][i2][j2] += gradOutput * v1;
                        gradInput1[c][i1][j1] += gradOutput * v2;
                    }
                }
            }
        }
    }
}

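// Forward pass on CPU. input1 and input2 have shape (B, C, iH, iW); the
// output has shape (B, patchH, patchW, oH, oW), where
// oH = (iH + 2 * padH - dilatedKH) / dH + 1 and oW is defined analogously.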
torch::Tensor correlation_cpp_forward(torch::Tensor input1, torch::Tensor input2, int kH, int kW, int patchH,
                                      int patchW, int padH, int padW, int dilationH, int dilationW, int dilation_patchH,
                                      int dilation_patchW, int dH, int dW) {
    const auto batch_size = input1.size(0);
    const auto iH = input1.size(2);
    const auto iW = input1.size(3);
    const int patchRadH = (patchH - 1) / 2;
    const int patchRadW = (patchW - 1) / 2;
    const int dilatedKH = (kH - 1) * dilationH + 1;
    const int dilatedKW = (kW - 1) * dilationW + 1;

    const auto oH = (iH + 2 * padH - dilatedKH) / dH + 1;
    const auto oW = (iW + 2 * padW - dilatedKW) / dW + 1;
    auto output = at::zeros({batch_size, patchH, patchW, oH, oW}, input1.options());

    int n, ph, pw, h, w;
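    // The indices are declared outside the loops so they can be listed in the
    // OpenMP private clause; collapse(2) distributes the batch and patch-row
    // loops across threads. Each (n, ph) pair writes a disjoint slice of
    // output, so no synchronization is needed.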
#pragma omp parallel for private(n, ph, pw, h, w) collapse(2)
    for (n = 0; n < batch_size; ++n) {
        for (ph = 0; ph < patchH; ++ph) {
            for (pw = 0; pw < patchW; ++pw) {
                AT_DISPATCH_FLOATING_TYPES(
                    input1.scalar_type(), "correlation_forward_cpp", ([&] {
                        auto input1_acc = input1.accessor<scalar_t, 4>();
                        auto input2_acc = input2.accessor<scalar_t, 4>();
                        auto output_acc = output.accessor<scalar_t, 5>();
                        for (h = 0; h < oH; ++h) {
                            for (w = 0; w < oW; ++w) {
                                correlate_patch(input1_acc[n], input2_acc[n], &output_acc[n][ph][pw][h][w], kH, kW,
                                                dilationH, dilationW, -padH + h * dH, -padW + w * dW,
                                                (ph - patchRadH) * dilation_patchH, (pw - patchRadW) * dilation_patchW);
                            }
                        }
                    }));
            }
        }
    }
    return output;
}

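// Backward pass on CPU. gradOutput has the (B, patchH, patchW, oH, oW) shape
// produced by correlation_cpp_forward; the returned pair holds the gradients
// with respect to input1 and input2, matching their shapes.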
std::vector<torch::Tensor> correlation_cpp_backward(torch::Tensor input1, torch::Tensor input2,
                                                    torch::Tensor gradOutput, int kH, int kW, int patchH, int patchW,
                                                    int padH, int padW, int dilationH, int dilationW,
                                                    int dilation_patchH, int dilation_patchW, int dH, int dW) {
    const int batch_size = input1.size(0);
    const int patchRadH = (patchH - 1) / 2;
    const int patchRadW = (patchW - 1) / 2;
    const int oH = gradOutput.size(3);
    const int oW = gradOutput.size(4);

    auto gradInput1 = torch::zeros_like(input1);
    auto gradInput2 = torch::zeros_like(input2);

    int n, ph, pw, h, w;
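    // Only the batch loop is parallelized here: for a fixed n, different
    // (ph, pw, h, w) combinations accumulate into overlapping gradInput
    // entries, so collapsing over the patch loops would introduce data races.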
#pragma omp parallel for private(n, ph, pw, h, w)
    for (n = 0; n < batch_size; ++n) {
        AT_DISPATCH_FLOATING_TYPES(
            input1.scalar_type(), "correlation_backward_cpp", ([&] {
                auto input1_acc = input1.accessor<scalar_t, 4>();
                auto gradInput1_acc = gradInput1.accessor<scalar_t, 4>();
                auto input2_acc = input2.accessor<scalar_t, 4>();
                auto gradInput2_acc = gradInput2.accessor<scalar_t, 4>();
                auto gradOutput_acc = gradOutput.accessor<scalar_t, 5>();

                for (ph = 0; ph < patchH; ++ph) {
                    for (pw = 0; pw < patchW; ++pw) {
                        for (h = 0; h < oH; ++h) {
                            for (w = 0; w < oW; ++w) {
                                correlate_patch_grad(input1_acc[n], gradInput1_acc[n], input2_acc[n], gradInput2_acc[n],
                                                     gradOutput_acc[n][ph][pw][h][w], kH, kW, dilationH, dilationW,
                                                     -padH + h * dH, -padW + w * dW, (ph - patchRadH) * dilation_patchH,
                                                     (pw - patchRadW) * dilation_patchW);
                            }
                        }
                    }
                }
            }));
    }

    return {gradInput1, gradInput2};
}
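
// A minimal binding sketch, assuming this translation unit is built as a
// standalone PyTorch extension and that bindings do not already live
// elsewhere in the project. TORCH_EXTENSION_NAME is defined by PyTorch's
// build helpers (torch.utils.cpp_extension); the exposed names "forward" and
// "backward" are illustrative, not part of the original source.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &correlation_cpp_forward, "correlation forward (CPU)");
    m.def("backward", &correlation_cpp_backward, "correlation backward (CPU)");
}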
